import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from mpl_toolkits.mplot3d import Axes3D
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
import sounddevice as sd
import threading
from scipy.special import jv, jn_zeros

# -------------------------------
# Synthetic training data
# -------------------------------
# Input sweep: 100 evenly spaced frequency values from 100 to 1000.
frequencies = np.linspace(100, 1000, 100)
# Synthetic target curve: two sinusoids of frequency plus Gaussian noise (std 0.05).
# NOTE(review): no RNG seed is set in the visible code, so this data -- and the
# model fitted from it -- presumably changes on every run; confirm that is intended.
nodal_counts = 3 * np.sin(0.01 * frequencies) + 0.5 * np.cos(0.005 * frequencies) + np.random.normal(0, 0.05, len(frequencies))

# Feature matrix: seven hand-built functions of frequency stacked as rows, then
# transposed so that each row is one sample (100 x 7).
X = np.vstack([
    frequencies,
    frequencies**2,
    np.sin(frequencies/100),
    np.cos(frequencies/200),
    np.power(1.618, frequencies/500),
    np.sqrt(frequencies),
    np.log1p(frequencies)
]).T

# Target matrix: four rescaled copies of nodal_counts, each with its own extra
# Gaussian noise, transposed to 100 x 4.
Y = np.vstack([
    nodal_counts + 0.1 * np.random.randn(len(frequencies)),
    nodal_counts * 2 + 0.2 * np.random.randn(len(frequencies)),
    nodal_counts * 0.7 + 0.1 * np.random.randn(len(frequencies)),
    nodal_counts * 1.1 + 0.2 * np.random.randn(len(frequencies))
]).T

# Hold out 20% of the samples for scoring; the split itself is seeded.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
# Linear least-squares fit mapping the 7 features to the 4 targets.
model = LinearRegression().fit(X_train, Y_train)

# Report the fitted coefficients, the intercepts, and the score on the held-out rows.
print("Learned T matrix:\n", model.coef_)
print("Intercept:\n", model.intercept_)
print("R^2 test score:", model.score(X_test, Y_test))

# -------------------------------
# φ-hash tuning
# -------------------------------
# Golden ratio. Used as the logarithm base in log_phi, in the default `beta` of
# choose_modes, and for the scale/amplitude factors in generate_polar.
# Note that safe_phi_hash carries its own literal default (1.6180339887) and the
# feature expressions use the shorter literal 1.618 rather than this constant.
phi = (1 + 5**0.5) / 2.0

def safe_phi_hash(x, phi=1.6180339887):
    """Return the fractional part of ``phi`` raised to ``x`` wrapped into [0, 1000).

    Args:
        x: Scalar or array-like of exponents.
        phi: Base of the power; defaults to a truncated golden-ratio literal.

    Returns:
        ``(phi ** (x mod 1000)) mod 1.0``, element-wise, with the same shape
        NumPy gives for ``x``.
    """
    # Wrapping the exponent bounds the size of the power before it is taken.
    wrapped_exponent = np.mod(x, 1000)
    powered = np.power(phi, wrapped_exponent)
    # NOTE(review): with float64 and the default base, exponents above roughly
    # 75 push the power past 2**52, where every representable value is an
    # integer, so the fractional part comes back as 0.0 -- confirm intended.
    return np.mod(powered, 1.0)

def conditional_phi_tune(y_pred, x, y_true, lam=0.01):
    """Offset a prediction by a phi-hash term, keeping it only if it helps.

    The mean of ``safe_phi_hash(x)`` is added, scaled by ``lam``, to every
    component of ``y_pred``. The shifted vector is returned only when its
    residual against ``y_true`` has a strictly smaller norm than the
    unshifted one.

    Args:
        y_pred: Predicted values to adjust.
        x: Sized sequence/array hashed with ``safe_phi_hash``.
        y_true: Reference values used to judge the adjustment.
        lam: Weight applied to the hash offset.

    Returns:
        The shifted prediction if it is closer to ``y_true``, else ``y_pred``.
    """
    hash_vec = safe_phi_hash(x)
    # Same mean value in every slot, shaped like the prediction.
    mean_field = np.ones_like(y_pred) * np.sum(hash_vec) / len(hash_vec)
    y_candidate = y_pred + lam * mean_field

    baseline_err = np.linalg.norm(y_pred - y_true)
    candidate_err = np.linalg.norm(y_candidate - y_true)
    return y_candidate if candidate_err < baseline_err else y_pred

# -------------------------------
# Cymatic params
# -------------------------------
def get_cymatic_params(f_query):
    """Predict the four pattern parameters for a single frequency.

    Builds the seven-feature row for ``f_query`` (the same expressions used
    for the training matrix ``X``), runs it through the fitted ``model``,
    passes the prediction through ``conditional_phi_tune`` with the mean of
    ``Y_test`` as the reference, and returns absolute values.

    Args:
        f_query: Scalar frequency to evaluate.

    Returns:
        Dict with keys ``"alpha"``, ``"beta"``, ``"eta"``, ``"zeta"`` holding
        the absolute value of the four tuned model outputs, in that order.
    """
    feature_row = np.array([
        f_query,
        f_query**2,
        np.sin(f_query/100),
        np.cos(f_query/200),
        np.power(1.618, f_query/500),
        np.sqrt(f_query),
        np.log1p(f_query)
    ])
    # predict() expects a 2-D (samples x features) array; take the single row back out.
    predicted = model.predict(feature_row.reshape(1, -1))[0]
    reference = Y_test.mean(axis=0)
    tuned = conditional_phi_tune(predicted, feature_row, reference)

    names = ("alpha", "beta", "eta", "zeta")
    return {name: abs(tuned[idx]) for idx, name in enumerate(names)}

# -------------------------------
# Pattern generation
# -------------------------------
def log_phi(x: float) -> float:
    """Return the base-``phi`` logarithm of ``x`` (offset by 1e-12)."""
    # The tiny offset keeps the logarithm finite when x is exactly 0.
    shifted = x + 1e-12
    return np.log(shifted) / np.log(phi)

def choose_modes(f: float, f0: float = 110.0, alpha=1.0, beta=1.0/phi, gamma=0.0, max_n=4, max_m=4, top_k=4):
    """Pick the ``(n, m)`` index pairs whose linear score best matches ``f``.

    The target is ``log_phi(f / f0)`` (with ``f`` floored at 1e-12). Every pair
    with ``0 <= n <= max_n`` and ``0 <= m <= max_m`` is scored by
    ``|alpha*n + beta*m + gamma - target|``.

    Args:
        f: Query frequency.
        f0: Reference frequency dividing ``f``.
        alpha, beta, gamma: Coefficients of the linear score.
        max_n, max_m: Inclusive upper bounds of the index grid.
        top_k: Number of pairs to return.

    Returns:
        List of up to ``top_k`` ``(n, m)`` tuples, smallest score first; ties
        keep their enumeration order (n-major) because the sort is stable.
    """
    target = log_phi(max(f, 1e-12) / f0)
    scored = [
        ((n, m), abs(alpha * n + beta * m + gamma - target))
        for n in range(max_n + 1)
        for m in range(max_m + 1)
    ]
    ranked = sorted(scored, key=lambda item: item[1])
    return [pair for pair, _ in ranked[:top_k]]

def generate_cartesian(coords, params):
    """Evaluate the rectangular (sine-product) surface on a coordinate grid.

    Args:
        coords: Pair ``(Xc, Yc)`` of coordinate arrays.
        params: Mapping providing ``"alpha"``, ``"beta"``, ``"eta"``, ``"zeta"``.

    Returns:
        ``sin(alpha*pi*u) * sin(beta*pi*v) + eta * cos(zeta*pi*(u + v))`` where
        ``u = (Xc + 1) / 2`` and ``v = (Yc + 1) / 2``.
    """
    grid_x, grid_y = coords
    # (c + 1) / 2 maps the interval [-1, 1] onto [0, 1].
    u = (grid_x + 1) / 2
    v = (grid_y + 1) / 2

    standing = np.sin(params["alpha"] * np.pi * u) * np.sin(params["beta"] * np.pi * v)
    diagonal = params["eta"] * np.cos(params["zeta"] * np.pi * (u + v))
    return standing + diagonal

def generate_polar(coords, f_query, params, plate_radius=1.0, inner_radius=0.3, radial_zeros_per_mode=2):
    """Evaluate the Bessel-mode surface on an annulus of the coordinate grid.

    For every ``(n, m)`` returned by ``choose_modes(f_query)`` and each of the
    first ``radial_zeros_per_mode`` zeros of ``J_m``, a term
    ``amp * J_m(k * R) * cos(m * Theta + phase)`` is accumulated.

    Args:
        coords: Pair ``(Xc, Yc)`` of coordinate arrays.
        f_query: Frequency handed to ``choose_modes``.
        params: Mapping providing ``"eta"`` and ``"zeta"`` (amplitude decay).
        plate_radius: Outer radius of the region that is kept.
        inner_radius: Inner radius of the region that is kept.
        radial_zeros_per_mode: How many Bessel zeros to use per angular order.

    Returns:
        Array shaped like the grid radius; zero wherever the radius is outside
        ``[inner_radius, plate_radius]``.
    """
    grid_x, grid_y = coords
    radius = np.sqrt(grid_x**2 + grid_y**2)
    angle = np.arctan2(grid_y, grid_x)
    # Keep only the ring between the inner and outer radius.
    annulus = (radius <= plate_radius) & (radius >= inner_radius)

    field = np.zeros_like(radius)
    for n, m in choose_modes(f_query):
        for j_idx, root in enumerate(jn_zeros(m, radial_zeros_per_mode)):
            # Each successive zero uses an effective radius shrunk by 1/phi.
            shrink = phi ** (-j_idx)
            wavenumber = root / (plate_radius * shrink + 1e-12)
            # Amplitude falls off as a power of phi in n, m and the zero index.
            weight = phi ** (-(params["eta"] * n + params["zeta"] * m + 0.5 * j_idx))
            offset = (n + m + j_idx) * 0.37
            term = weight * jv(m, wavenumber * radius) * np.cos(m * angle + offset)
            field += term * annulus

    return field * annulus

# -------------------------------
# Audio playback
# -------------------------------
# Samples per second used both to synthesize the tone and to play it back.
sample_rate = 44100
# Buffer length in seconds used when synthesizing the tone; AudioPlayer.play_tone
# plays that buffer with loop=True, so this is not the audible tone length.
duration = 0.1

class AudioPlayer:
    """Plays a continuously looping sine tone through ``sounddevice``.

    ``play_tone`` (re)starts playback only when the requested frequency differs
    from the current one or nothing is playing; ``stop`` halts playback.
    The module-level ``sample_rate`` and ``duration`` set the playback rate and
    the approximate length of the looped buffer.
    """

    def __init__(self):
        # Frequency passed to the most recent tone start (None before the first).
        self.current_freq = None
        # True from the moment play_tone starts a tone until stop() is called.
        self.playing = False
        # Not used by any method of this class; kept so that existing attribute
        # access from outside keeps working.
        self.thread = None

    @staticmethod
    def _loop_waveform(frequency, rate, length_s, amplitude=0.2):
        """Build a float32 sine buffer that repeats without a seam.

        A buffer of exactly ``length_s`` seconds only contains a whole number
        of cycles when ``frequency * length_s`` is an integer; for any other
        frequency the waveform jumps each time a looping player wraps around,
        which is heard as a periodic click. This helper instead rounds to the
        nearest whole number of cycles and synthesizes exactly that many
        periods across the buffer.

        Args:
            frequency: Requested tone frequency.
            rate: Samples per second.
            length_s: Target buffer length in seconds.
            amplitude: Peak amplitude of the sine.

        Returns:
            1-D ``float32`` array. Its effective pitch is
            ``cycles * rate / len(buffer)``, which differs from ``frequency``
            by at most ``0.5 / len(buffer)`` in relative terms.
        """
        if not frequency > 0:
            # No positive period to align to (zero, negative or NaN input):
            # fall back to a plain fixed-length buffer.
            t = np.linspace(0, length_s, int(rate * length_s), endpoint=False, dtype=np.float32)
            return (amplitude * np.sin(2 * np.pi * frequency * t)).astype(np.float32)

        cycles = max(1, int(round(frequency * length_s)))
        n_samples = max(1, int(round(cycles * rate / frequency)))
        # endpoint=False: the sample after the last one would be phase 2*pi*cycles,
        # i.e. identical to sample 0, so the loop restart is continuous.
        phase = 2 * np.pi * cycles * np.linspace(0.0, 1.0, n_samples, endpoint=False)
        return (amplitude * np.sin(phase)).astype(np.float32)

    def play_tone(self, frequency):
        """Start looping a sine tone at ``frequency``.

        Does nothing when that frequency is already playing. Otherwise any
        current playback is stopped and a new looped buffer is started.
        """
        if self.current_freq == frequency and self.playing:
            return
        sd.stop()
        self.current_freq = frequency
        self.playing = True
        waveform = self._loop_waveform(frequency, sample_rate, duration)
        sd.play(waveform, samplerate=sample_rate, loop=True)

    def stop(self):
        """Stop playback and mark the player as idle."""
        self.playing = False
        sd.stop()

# Single shared player used by the slider callback and the start-up/shutdown code.
audio_player = AudioPlayer()

def note_to_freq(note_val):
    """Convert a note offset into a frequency.

    An offset of 0 maps to 220.0, and every 12 steps doubles the result.

    Args:
        note_val: Note offset; may be fractional or negative.

    Returns:
        ``220.0 * 2 ** (note_val / 12)``.
    """
    octaves = note_val / 12
    return 220.0 * 2 ** octaves

# -------------------------------
# Prepare grid
# -------------------------------
# Grid resolution along each axis.
Nx, Ny = 50, 50  # Further reduced for performance
# Sample points spanning [-1, 1] on both axes.
x = np.linspace(-1, 1, Nx)
y = np.linspace(-1, 1, Ny)
# 2-D coordinate arrays passed to generate_cartesian / generate_polar and to plot_surface.
Xg, Yg = np.meshgrid(x, y)

# -------------------------------
# Plot + sliders
# -------------------------------
# One 3D axes in an 8x8 figure; subplots_adjust raises the bottom margin to
# leave room for the two slider axes created below.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
plt.subplots_adjust(bottom=0.3)

# Blend factor between the two surfaces: 0.0 uses only generate_cartesian,
# 1.0 uses only generate_polar.
morph_slider_ax = plt.axes([0.15, 0.15, 0.65, 0.03])
morph_slider = Slider(morph_slider_ax, "Morph (Cart. ↔ Polar)", 0.0, 1.0, valinit=0.0)

# Note offset (0 to 72) that note_to_freq converts into the query frequency.
note_slider_ax = plt.axes([0.15, 0.1, 0.65, 0.03])
note_slider = Slider(note_slider_ax, "Note", 0.0, 72.0, valinit=0.0)

# Zoom control
def on_scroll(event):
    """Zoom the 3D axes in response to a mouse-wheel event.

    Scrolling "up" multiplies the span of every axis by 1.1; any other scroll
    direction multiplies it by 0.9. Each range is scaled about its own
    midpoint, so a view whose limits are not centred on 0 stays centred where
    it is instead of drifting toward or away from the origin. For limits that
    are symmetric about 0 this gives the same result as multiplying the limits
    directly.

    Args:
        event: Matplotlib scroll event; ignored unless it occurred over ``ax``.
    """
    if event.inaxes != ax:
        return
    scale = 1.1 if event.button == 'up' else 0.9

    def _scaled(limits):
        # Resize the (low, high) range about its midpoint.
        low, high = limits
        mid = (low + high) / 2.0
        half_span = (high - low) / 2.0 * scale
        return [mid - half_span, mid + half_span]

    # Read all three ranges before changing any of them.
    curr_xlim = ax.get_xlim3d()
    curr_ylim = ax.get_ylim3d()
    curr_zlim = ax.get_zlim3d()
    ax.set_xlim3d(_scaled(curr_xlim))
    ax.set_ylim3d(_scaled(curr_ylim))
    ax.set_zlim3d(_scaled(curr_zlim))
    fig.canvas.draw_idle()

# Route scroll events on the figure canvas to the zoom handler.
fig.canvas.mpl_connect('scroll_event', on_scroll)

# Initial draw: render the surface for the sliders' starting values.
params = get_cymatic_params(note_to_freq(note_slider.val))
Z_cart = generate_cartesian((Xg, Yg), params)
Z_polar = generate_polar((Xg, Yg), note_to_freq(note_slider.val), params)
# Linear blend of the two surfaces, weighted by the morph slider.
Z = (1 - morph_slider.val) * Z_cart + morph_slider.val * Z_polar
# NOTE(review): `surf` is not read again anywhere in the visible code.
surf = ax.plot_surface(Xg, Yg, Z, cmap='viridis', rstride=1, cstride=1, linewidth=0, antialiased=True)
# Fixed z range and box aspect, matching what update() applies on every redraw.
ax.set_zlim(-2, 2)
ax.set_title(f"Note: {note_slider.val:.1f}, Freq: {note_to_freq(note_slider.val):.2f} Hz")
ax.set_box_aspect([1, 1, 0.5])

def update(val):
    """Slider callback: redraw the blended surface and title for the current sliders.

    Args:
        val: Value supplied by the slider that changed. It is not used; both
            sliders are read directly so either one can trigger a full redraw.
    """
    ax.clear()

    note_val = note_slider.val
    blend = morph_slider.val
    f_query = note_to_freq(note_val)
    cymatic = get_cymatic_params(f_query)

    # Linear mix: blend == 0 shows the cartesian surface, blend == 1 the polar one.
    grid = (Xg, Yg)
    surface = (1 - blend) * generate_cartesian(grid, cymatic) + blend * generate_polar(grid, f_query, cymatic)

    ax.plot_surface(Xg, Yg, surface, cmap='viridis', rstride=1, cstride=1, linewidth=0, antialiased=True)
    # Limits, aspect and title are re-applied after every ax.clear().
    for set_limits, bounds in ((ax.set_xlim, (-1, 1)), (ax.set_ylim, (-1, 1)), (ax.set_zlim, (-2, 2))):
        set_limits(*bounds)
    ax.set_box_aspect([1, 1, 0.5])
    ax.set_title(f"Note: {note_val:.1f}, Freq: {f_query:.2f} Hz")
    fig.canvas.draw_idle()

    audio_player.play_tone(f_query)

# Both sliders share one callback, which redraws the surface and updates the tone.
morph_slider.on_changed(update)
note_slider.on_changed(update)

# Start the tone for the note slider's initial value.
audio_player.play_tone(note_to_freq(note_slider.val))

plt.show()

# Runs once plt.show() returns (presumably after the window has been closed):
# close any remaining figures and stop the looping audio.
plt.close('all')
audio_player.stop()